
Conversation

@jpdunc23 (Member) commented Feb 9, 2026

Adds train_stepper: CoupledTrainStepperConfig to the coupled training config, which configures and builds a CoupledTrainStepper implementing TrainStepperABC.

WARNING: This is a breaking change for existing coupled training configs.

Changes:

  • Component stepper loss: StepLossConfig and loss_contributions: LossContributionsConfig are now configured via the ocean: ComponentTrainingConfig and atmosphere: ComponentTrainingConfig attributes of CoupledTrainStepperConfig (see the sketch after this list).

  • CoupledStepper no longer implements TrainStepperABC.

  • Removed public loss_obj and effective_loss_scaling properties from fme.ace.stepper.Stepper and added a new public method build_loss.

  • Tests added
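
For orientation, here is a rough Python sketch of how the reworked config could be laid out, inferred only from the class names mentioned in this PR; the field defaults and exact nesting are assumptions, not the actual fme definitions.

```python
# Rough sketch of the layout described above; defaults and nesting are
# assumptions based on the class names in this PR, not the actual fme code.
import dataclasses


@dataclasses.dataclass
class StepLossConfig:
    type: str = "MSE"  # assumed default


@dataclasses.dataclass
class LossContributionsConfig:
    n_steps: int = 1     # assumed default
    weight: float = 1.0  # assumed default


@dataclasses.dataclass
class ComponentTrainingConfig:
    loss: StepLossConfig = dataclasses.field(default_factory=StepLossConfig)
    loss_contributions: LossContributionsConfig = dataclasses.field(
        default_factory=LossContributionsConfig
    )


@dataclasses.dataclass
class CoupledTrainStepperConfig:
    ocean: ComponentTrainingConfig = dataclasses.field(
        default_factory=ComponentTrainingConfig
    )
    atmosphere: ComponentTrainingConfig = dataclasses.field(
        default_factory=ComponentTrainingConfig
    )
```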

```diff
     ) -> StepLossABC:
         if self.n_steps == 0 or self.weight == 0.0:
-            return NullLossContributions()
+            return NullLossContributions(loss_obj)
```
@jpdunc23 (Member Author) commented:
This preserves the existing behavior where we used a component stepper's effective_loss_scaling to compute mse_fractional_components metrics even if the stepper had no loss contribution in coupled training.
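
A minimal sketch of that idea (not the actual fme class): the null contribution keeps a reference to the component's loss object purely so its scaling stays available for metrics, while contributing nothing to the coupled training loss. The method names and return types here are assumptions.

```python
# Minimal sketch of a null contribution that retains the component's loss
# object for metrics; names and return types are assumptions, not fme's API.
import torch


class NullLossContributions:
    def __init__(self, loss_obj=None):
        # retained only so effective_loss_scaling stays available for
        # metrics like mse_fractional_components
        self._loss_obj = loss_obj

    @property
    def effective_loss_scaling(self):
        if self._loss_obj is None:
            return None
        return self._loss_obj.effective_loss_scaling

    def __call__(self, *args, **kwargs):
        # contributes nothing to the coupled training loss
        return torch.tensor(0.0)
```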

@jpdunc23 jpdunc23 marked this pull request as ready for review February 10, 2026 16:53
@mcgibbon (Contributor) left a comment:

Just some nits (nits are optional); I don't need to re-review them. LGTM


```python
    @property
    def effective_loss_scaling(self):
        raise NotImplementedError
```
@mcgibbon (Contributor) commented:

Suggested change:

```diff
-        raise NotImplementedError
+        raise NotImplementedError()
```

```diff
     atmos_loss_config = LossContributionsConfig()
     atmosphere_loss = atmos_loss_config.build(
-        loss_obj=lambda *_, **__: torch.tensor(5.25),
+        loss_obj=Mock(spec=StepLoss, side_effect=lambda *_, **__: torch.tensor(5.25)),
```
@mcgibbon (Contributor) commented Feb 10, 2026:
nit: The three changed lines use three different ways to specify the loss side effect: via mae_loss, via a lambda returning a constant, and via a return_value instead of a side_effect. You could consider using return_value for this one to reduce that to two ways, at least.
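
To illustrate the point with plain unittest.mock (this is not the test code itself, and the spec=StepLoss part is omitted here): a constant-returning side_effect and a return_value behave the same on call, so the latter is the simpler spelling.

```python
# Illustration only, not the test code: for a mock that always returns the
# same tensor, return_value is equivalent to a constant-returning side_effect.
from unittest.mock import Mock

import torch

via_side_effect = Mock(side_effect=lambda *_, **__: torch.tensor(5.25))
via_return_value = Mock(return_value=torch.tensor(5.25))

assert torch.equal(via_side_effect(), via_return_value())
```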

```diff
         n_samples=3,
     )
-    output = coupler.train_on_batch(
+    train_stepper_config = CoupledTrainStepperConfig(
```
@mcgibbon (Contributor) commented:
nit: Avoid the 3x copy-paste of this process by making a get_train_stepper_and_batch helper that does it and calls get_stepper_and_batch internally.

@jpdunc23 (Member Author) commented:
Good idea, but I'll defer this cleanup to #814 since the way in which the train stepper is built is going to change.



```python
@dataclasses.dataclass
class CoupledTrainStepperConfig:
```
@mcgibbon (Contributor) commented:
Do you have an example of the updated training config committed somewhere I could check out? It would be nice to have a baseline config for coupled training; if so, I could see the changes to the baseline in this PR.

Update: Ah, I see test_train.py mostly fits this purpose, good. Still, it could be nice to have a baseline in the future.

@jpdunc23 (Member Author) commented:
Agreed, I will work on a new PR to add the baseline.

```yaml
    loss_contributions:
      n_steps: {loss_atmos_n_steps}
  stepper:
    loss:
```
@mcgibbon (Contributor) commented:
Question: Why is loss: type: MSE in both the atmosphere: stepper: and in the train_stepper: atmosphere:? I am guessing because we haven't updated the ACE configs yet and it's required in the config, in which case that's fine, but I thought I should ask to be sure.

@jpdunc23 (Member Author) commented:
That's right, although including it in the yaml here isn't strictly necessary since there is a default value on StepperConfig. I'll remove it so it's a bit clearer here.
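
As a small illustration of why the key can be dropped (using stand-in classes, not the real fme configs): a dataclass field with a default means an omitted loss: section simply falls back to that default when the config object is built.

```python
# Stand-in classes, not the real fme configs: omitting the loss section
# leaves the dataclass default in place.
import dataclasses


@dataclasses.dataclass
class StepLossConfig:
    type: str = "MSE"  # assumed default


@dataclasses.dataclass
class StepperConfig:
    loss: StepLossConfig = dataclasses.field(default_factory=StepLossConfig)


config = StepperConfig()  # no explicit loss: needed in the yaml or here
assert config.loss.type == "MSE"
```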

```diff
         atmosphere_normalize=stepper.atmosphere.normalizer.normalize,
-        ocean_loss_scaling=stepper.ocean.effective_loss_scaling,
-        atmosphere_loss_scaling=stepper.atmosphere.effective_loss_scaling,
+        ocean_loss_scaling=stepper.effective_loss_scaling.ocean,
```
@mcgibbon (Contributor) commented:
nit: pass loss_scaling: stepper.loss_scaling instead of two arguments containing the parts
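
A sketch of what that could look like, with a hypothetical container name (CoupledLossScaling and compute_derived_metrics are made up for illustration; the real object would be whatever stepper.loss_scaling returns):

```python
# Hypothetical shape of the suggested refactor: one container argument in
# place of the separate ocean/atmosphere scaling arguments. Names here are
# illustrative, not fme's.
import dataclasses

import torch


@dataclasses.dataclass
class CoupledLossScaling:
    ocean: dict[str, torch.Tensor]
    atmosphere: dict[str, torch.Tensor]


def compute_derived_metrics(loss_scaling: CoupledLossScaling) -> None:
    # downstream code reads loss_scaling.ocean / loss_scaling.atmosphere
    # instead of receiving them as two separate keyword arguments
    ...
```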

@jpdunc23 jpdunc23 enabled auto-merge (squash) February 10, 2026 21:27
@jpdunc23 jpdunc23 merged commit aea4317 into main Feb 11, 2026
7 checks passed
@jpdunc23 jpdunc23 deleted the refactor/coupled-train-stepper branch February 11, 2026 07:56
